{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "JAX BiT",
      "provenance": [],
      "collapsed_sections": [
        "zDMzHsUFsmRu",
        "kJHowEFBZnAN"
      ],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sXhZm0kpPpH6",
        "colab_type": "text"
      },
      "source": [
        "##### Copyright 2020 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KfmzfvFxPuk7",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iOVCm4CnP1Do",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_jax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aVW_4I33Pqoi"
      },
      "source": [
        "# BigTransfer (BiT): A step-by-step tutorial for state-of-the-art vision\n",
        "\n",
        "This colab demonstrates how to:\n",
        "1. Load BiT models in JAX.\n",
        "2. Make predictions using BiT already fine-tuned on CIFAR-10.\n",
        "3. Fine-tune BiT on 5-shot CIFAR-100 and get amazing results!\n",
        "\n",
        "This colab is a good way to get an understanding of BiT or to quickly try things out. For longer training runs, however, we recommend using the command-line scripts at http://github.com/google-research/big_transfer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q9nosNf1PdEL",
        "colab_type": "text"
      },
      "source": [
        "# Install flax and run imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2MCUeLqhqh6d",
        "colab_type": "code",
        "outputId": "942ebcbd-88a2-4255-8ee2-341140949be7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 238
        }
      },
      "source": [
        "# The notebook was run with flax 0.1.0 (see the recorded output); pin it\n",
        "# because the code below relies on the `flax.nn` / `flax.optim` modules.\n",
        "!pip install flax==0.1.0"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: flax in /usr/local/lib/python3.6/dist-packages (0.1.0)\n",
            "Requirement already satisfied: msgpack in /usr/local/lib/python3.6/dist-packages (from flax) (1.0.0)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from flax) (3.2.1)\n",
            "Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from flax) (1.18.4)\n",
            "Requirement already satisfied: jax>=0.1.59 in /usr/local/lib/python3.6/dist-packages (from flax) (0.1.64)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from flax) (0.7)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->flax) (0.10.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->flax) (1.2.0)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->flax) (2.4.7)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->flax) (2.8.1)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.59->flax) (0.9.0)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.59->flax) (3.2.1)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib->flax) (1.12.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MKHdxUTJRiUF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import io\n",
        "import re\n",
        "\n",
        "from functools import partial\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import flax\n",
        "import flax.nn as nn\n",
        "import flax.optim as optim\n",
        "import flax.jax_utils as flax_utils\n",
        "\n",
        "# Assert that GPU is available\n",
        "assert 'Gpu' in str(jax.devices()), (\n",
        "    'No GPU device found. In Colab, select Runtime > Change runtime type > '\n",
        "    'Hardware accelerator > GPU and re-run.')\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mWzT0g8x6feo",
        "colab_type": "text"
      },
      "source": [
        "# Architecture and function for transforming BiT weights to JAX format"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NjTie1S4RwmE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def fixed_padding(x, kernel_size):\n",
        "  \"\"\"Zero-pads axes 1 and 2 of a rank-4 array for a `kernel_size` window.\n",
        "\n",
        "  A total of `kernel_size - 1` zeros is added along each of the two padded\n",
        "  axes; when that total is odd the extra element goes at the end. Axes 0 and\n",
        "  3 are left untouched (presumably batch and channels of an NHWC input).\n",
        "  \"\"\"\n",
        "  total = kernel_size - 1\n",
        "  before = total // 2\n",
        "  after = total - before\n",
        "\n",
        "  # One (low, high, interior) triple per axis, as expected by jax.lax.pad.\n",
        "  keep = (0, 0, 0)\n",
        "  grow = (before, after, 0)\n",
        "  return jax.lax.pad(x, 0.0, (keep, grow, grow, keep))\n",
        "\n",
        "\n",
        "def standardize(x, axis, eps):\n",
        "  \"\"\"Shifts `x` to zero mean and scales it to unit variance over `axis`.\n",
        "\n",
        "  `eps` is added to the variance before the square root is taken.\n",
        "  \"\"\"\n",
        "  centered = x - jnp.mean(x, axis=axis, keepdims=True)\n",
        "  variance = jnp.mean(jnp.square(centered), axis=axis, keepdims=True)\n",
        "  return centered / jnp.sqrt(variance + eps)\n",
        "\n",
        "\n",
        "class GroupNorm(nn.Module):\n",
        "  \"\"\"Group normalization (arxiv.org/abs/1803.08494).\"\"\"\n",
        "\n",
        "  def apply(self, x, num_groups=32):\n",
        "    \"\"\"Normalizes `x` per group, then applies a learned scale and bias.\n",
        "\n",
        "    The last axis of the rank-4 input is split into `num_groups` groups;\n",
        "    statistics are taken over axes 1 and 2 and the channels of each group.\n",
        "    \"\"\"\n",
        "    orig_shape = x.shape\n",
        "    channels = orig_shape[-1]\n",
        "\n",
        "    # Split the last axis into (group, channel-within-group).\n",
        "    grouped = x.reshape(orig_shape[:-1] + (num_groups, channels // num_groups))\n",
        "\n",
        "    # Standardize along spatial and group dimensions\n",
        "    normed = standardize(grouped, axis=[1, 2, 4], eps=1e-5)\n",
        "    normed = normed.reshape(orig_shape)\n",
        "\n",
        "    # One scale and one bias per channel, broadcast over the leading axes.\n",
        "    param_shape = (1, 1, 1, channels)\n",
        "    scale = self.param('scale', param_shape, nn.initializers.ones)\n",
        "    bias = self.param('bias', param_shape, nn.initializers.zeros)\n",
        "    return normed * scale + bias\n",
        "\n",
        "\n",
        "class StdConv(nn.Conv):\n",
        "  \"\"\"`nn.Conv` whose 'kernel' parameter is standardized when it is read.\"\"\"\n",
        "\n",
        "  def param(self, name, shape, initializer):\n",
        "    \"\"\"Returns the parameter; 'kernel' is standardized over axes 0, 1, 2.\"\"\"\n",
        "    value = super().param(name, shape, initializer)\n",
        "    if name != 'kernel':\n",
        "      return value\n",
        "    return standardize(value, axis=[0, 1, 2], eps=1e-10)\n",
        "\n",
        "\n",
        "class RootBlock(nn.Module):\n",
        "  \"\"\"Stem: padded 7x7 stride-2 conv followed by padded 3x3 stride-2 max-pool.\"\"\"\n",
        "\n",
        "  def apply(self, x, width):\n",
        "    \"\"\"Applies the stem to `x`; `width` is handed to the root convolution.\"\"\"\n",
        "    # Padding is added explicitly, so both ops below run with 'VALID'.\n",
        "    stem = StdConv(fixed_padding(x, 7), width, (7, 7), (2, 2),\n",
        "                   padding='VALID', bias=False, name='conv_root')\n",
        "    return nn.max_pool(fixed_padding(stem, 3), (3, 3),\n",
        "                       strides=(2, 2), padding='VALID')\n",
        "\n",
        "\n",
        "class ResidualUnit(nn.Module):\n",
        "  \"\"\"Bottleneck ResNet block.\"\"\"\n",
        "\n",
        "  def apply(self, x, nout, strides=(1, 1)):\n",
        "    \"\"\"Runs GroupNorm-ReLU-conv three times (1x1, 3x3, 1x1) plus a shortcut.\n",
        "\n",
        "    The last conv is asked for `nout * 4` features. The shortcut is the raw\n",
        "    input unless its channel count differs from that or `strides` is not\n",
        "    (1, 1); in that case it is a strided 1x1 projection of the first\n",
        "    normalized activation.\n",
        "    \"\"\"\n",
        "    out_channels = nout * 4\n",
        "    conv = StdConv.partial(bias=False)\n",
        "\n",
        "    shortcut = x\n",
        "    identity_ok = x.shape[-1] == out_channels and strides == (1, 1)\n",
        "\n",
        "    y = nn.relu(GroupNorm(x, name='gn1'))\n",
        "    if not identity_ok:\n",
        "      shortcut = conv(y, out_channels, (1, 1), strides, name='conv_proj')\n",
        "    y = conv(y, nout, (1, 1), name='conv1')\n",
        "\n",
        "    y = nn.relu(GroupNorm(y, name='gn2'))\n",
        "    # Explicit padding, hence 'VALID' for the only spatial convolution.\n",
        "    y = conv(fixed_padding(y, 3), nout, (3, 3), strides,\n",
        "             name='conv2', padding='VALID')\n",
        "\n",
        "    y = nn.relu(GroupNorm(y, name='gn3'))\n",
        "    y = conv(y, out_channels, (1, 1), name='conv3')\n",
        "\n",
        "    return y + shortcut\n",
        "\n",
        "\n",
        "class ResidualBlock(nn.Module):\n",
        "  \"\"\"A stack of `ResidualUnit`s named 'unit01', 'unit02', ...\"\"\"\n",
        "\n",
        "  def apply(self, x, block_size, nout, first_stride):\n",
        "    \"\"\"Chains the units; only the first one uses `first_stride`.\"\"\"\n",
        "    x = ResidualUnit(x, nout, strides=first_stride, name='unit01')\n",
        "    # The remaining units all use stride (1, 1).\n",
        "    for unit_number in range(2, block_size + 1):\n",
        "      x = ResidualUnit(x, nout, strides=(1, 1),\n",
        "                       name=f'unit{unit_number:02d}')\n",
        "    return x\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "  \"\"\"ResNetV2.\"\"\"\n",
        "\n",
        "  def apply(self, x, num_classes=1000,\n",
        "            width_factor=1, num_layers=50):\n",
        "    \"\"\"Maps a rank-4 input `x` to float32 logits with `num_classes` outputs.\n",
        "\n",
        "    `num_layers` selects the per-stage unit counts from `_block_sizes`, and\n",
        "    `width_factor` multiplies the base width of 64.\n",
        "    \"\"\"\n",
        "    stage_depths = _block_sizes[num_layers]\n",
        "    base_width = 64 * width_factor\n",
        "\n",
        "    x = RootBlock.partial(width=base_width)(x, name='root_block')\n",
        "\n",
        "    # Blocks: the width doubles at every stage; all but the first use a\n",
        "    # (2, 2) stride in their first unit.\n",
        "    for stage, depth in enumerate(stage_depths, start=1):\n",
        "      stride = (1, 1) if stage == 1 else (2, 2)\n",
        "      x = ResidualBlock(x, depth, base_width * 2 ** (stage - 1),\n",
        "                        first_stride=stride,\n",
        "                        name=f'block{stage}')\n",
        "\n",
        "    # Pre-head\n",
        "    x = nn.relu(GroupNorm(x, name='norm-pre-head'))\n",
        "    pooled = jnp.mean(x, axis=(1, 2))\n",
        "\n",
        "    # Head\n",
        "    logits = nn.Dense(pooled, num_classes, name='conv_head',\n",
        "                      kernel_init=nn.initializers.zeros)\n",
        "\n",
        "    return logits.astype(jnp.float32)\n",
        "\n",
        "\n",
        "# Number of residual units in each of the four stages, keyed by `num_layers`.\n",
        "_block_sizes = {\n",
        "    50: [3, 4, 6, 3],\n",
        "    101: [3, 4, 23, 3],\n",
        "    152: [3, 8, 36, 3],\n",
        "}\n",
        "\n",
        "\n",
        "def transform_params(params, params_tf, num_classes, init_head=False):\n",
        "  \"\"\"Copies BiT weights from `params_tf` into `params`, in place.\n",
        "\n",
        "  Args:\n",
        "    params: nested dict of model parameters; it is modified, not returned.\n",
        "    params_tf: flat dict from 'resnet/...' names to arrays.\n",
        "    num_classes: output size of the zero head built when `init_head` is False.\n",
        "    init_head: if True, take the head from `params_tf` instead of zeroing it.\n",
        "  \"\"\"\n",
        "  # BiT and JAX models have different naming conventions, so we need to\n",
        "  # properly map TF weights to JAX weights\n",
        "\n",
        "  def as_gn_param(vector):\n",
        "    # Prepend three singleton axes, as the group-norm scale/bias are stored.\n",
        "    return vector[None, None, None]\n",
        "\n",
        "  params['root_block']['conv_root']['kernel'] = (\n",
        "      params_tf['resnet/root_block/standardized_conv2d/kernel'])\n",
        "\n",
        "  for block in ('block1', 'block2', 'block3', 'block4'):\n",
        "    unit_names = {re.findall(r'unit\\d+', key)[0]\n",
        "                  for key in params_tf if block in key}\n",
        "    for unit in unit_names:\n",
        "      tf_unit = f'resnet/{block}/{unit}'\n",
        "      jax_unit = params[block][unit]\n",
        "      # TF groups 'a', 'b', 'c' correspond to conv/gn 1, 2, 3.\n",
        "      for number, group in enumerate('abc', start=1):\n",
        "        tf_group = f'{tf_unit}/{group}'\n",
        "        jax_unit[f'conv{number}']['kernel'] = (\n",
        "            params_tf[f'{tf_group}/standardized_conv2d/kernel'])\n",
        "        jax_unit[f'gn{number}']['bias'] = as_gn_param(\n",
        "            params_tf[f'{tf_group}/group_norm/beta'])\n",
        "        jax_unit[f'gn{number}']['scale'] = as_gn_param(\n",
        "            params_tf[f'{tf_group}/group_norm/gamma'])\n",
        "\n",
        "      # At most one projection kernel may exist per unit.\n",
        "      proj_keys = [key for key in params_tf\n",
        "                   if f'{block}/{unit}/a/proj' in key]\n",
        "      assert len(proj_keys) <= 1\n",
        "      if proj_keys:\n",
        "        jax_unit['conv_proj']['kernel'] = params_tf[proj_keys[0]]\n",
        "\n",
        "  params['norm-pre-head']['bias'] = as_gn_param(\n",
        "      params_tf['resnet/group_norm/beta'])\n",
        "  params['norm-pre-head']['scale'] = as_gn_param(\n",
        "      params_tf['resnet/group_norm/gamma'])\n",
        "\n",
        "  head = params['conv_head']\n",
        "  if not init_head:\n",
        "    # Fresh zero head; the input width is taken from the existing kernel.\n",
        "    in_features = head['kernel'].shape[0]\n",
        "    head['kernel'] = np.zeros((in_features, num_classes), dtype=np.float32)\n",
        "    head['bias'] = np.zeros(num_classes, dtype=np.float32)\n",
        "    return\n",
        "\n",
        "  # Indexing [0, 0] drops the two leading axes of the TF head kernel\n",
        "  # (presumably the spatial axes of a 1x1 convolution).\n",
        "  head['kernel'] = params_tf['resnet/head/conv2d/kernel'][0, 0]\n",
        "  head['bias'] = params_tf['resnet/head/conv2d/bias']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b8NxYJ3QT5Nb",
        "colab_type": "text"
      },
      "source": [
        "# Run BiT-M-ResNet50x1 already fine-tuned on CIFAR-10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oqI1eMG6QDJv",
        "colab_type": "text"
      },
      "source": [
        "## Build model and load weights"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EGNzuKlDS6yU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "with tf.io.gfile.GFile('gs://bit_models/BiT-M-R50x1-CIFAR10.npz', 'rb') as f:\n",
        "  # np.load on an .npz reads each array lazily from `f`, so copy all of them\n",
        "  # out (as JAX arrays) while the file is still open.\n",
        "  params_tf = {name: jnp.array(value) for name, value in np.load(f).items()}\n",
        "\n",
        "ResNet_cifar10 = ResNet.partial(num_classes=10)\n",
        "\n",
        "def resnet_fn(params, images):\n",
        "  \"\"\"Returns the 10-way logits of the model for a batch of images.\"\"\"\n",
        "  return ResNet_cifar10.call(params, images)\n",
        "\n",
        "# Initialize a parameter tree of the right structure, then overwrite it\n",
        "# in place with the BiT weights, including the head (init_head=True).\n",
        "resnet_init = ResNet_cifar10.init_by_shape\n",
        "_, params = resnet_init(jax.random.PRNGKey(0), [([1, 224, 224, 3], jnp.float32)])\n",
        "\n",
        "transform_params(params, params_tf, 10, init_head=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7plTTPh-QLtA",
        "colab_type": "text"
      },
      "source": [
        "## Prepare data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EB1fRfiRVlz_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "data_builder = tfds.builder('cifar10')\n",
        "data_builder.download_and_prepare()\n",
        "\n",
        "def _pp(data):\n",
        "  \"\"\"Resizes the image to 128x128 and maps pixels via (x - 127.5) / 127.5.\"\"\"\n",
        "  image = tf.image.resize(data['image'], [128, 128])\n",
        "  return {'image': (image - 127.5) / 127.5, 'label': data['label']}\n",
        "\n",
        "# CIFAR-10 test split, preprocessed and served as NumPy batches of 100.\n",
        "data = data_builder.as_dataset(split='test').map(_pp).batch(100)\n",
        "data_iter = data.as_numpy_iterator()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y-jGg4cZU7HT",
        "colab_type": "text"
      },
      "source": [
        "## Run BiT"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "n9VdD7lqVl2V",
        "colab_type": "code",
        "outputId": "bca40452-e992-4da2-84ad-2a07c1f847ba",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Iterate over a fresh iterator rather than the shared `data_iter`: the latter\n",
        "# is exhausted after one pass, so re-running this cell would see no batches\n",
        "# and fail with a division by zero.\n",
        "correct, n = 0, 0\n",
        "for batch in data.as_numpy_iterator():\n",
        "  preds = resnet_fn(params, batch['image'])\n",
        "  correct += np.sum(np.argmax(preds, axis=1) == batch['label'])\n",
        "  n += len(preds)\n",
        "\n",
        "print(f\"CIFAR-10 accuracy of BiT-M-R50x1: {correct / n:0.3%}\")"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "CIFAR-10 accuracy of BiT-M-R50x1: 97.640%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ddQ5crM0Qzdi",
        "colab_type": "text"
      },
      "source": [
        "# Run finetuning on CIFAR-100"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lHwHCUDfUfLc",
        "colab_type": "text"
      },
      "source": [
        "## Prepare data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u4xd_uX1Vl5q",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Rebinds the global `data_builder` (previously the CIFAR-10 builder) to\n",
        "# CIFAR-100; `get_data` below reads from this builder.\n",
        "data_builder = tfds.builder('cifar100')\n",
        "data_builder.download_and_prepare()\n",
        "\n",
        "def get_data(split, repeats, batch_size, images_per_class, shuffle_buffer):\n",
        "  \"\"\"Builds a shuffled, preprocessed and batched CIFAR-100 dataset.\n",
        "\n",
        "  For the 'train' split, a fixed (seeded) few-shot subset with\n",
        "  `images_per_class` examples of each of the 100 classes is selected and\n",
        "  augmented with a random crop and a random horizontal flip. Any other split\n",
        "  is used in full and only resized.\n",
        "\n",
        "  Args:\n",
        "    split: name of the tfds split; 'train' enables subsampling/augmentation.\n",
        "    repeats: forwarded to `Dataset.repeat`.\n",
        "    batch_size: number of examples per batch.\n",
        "    images_per_class: examples kept per class; only read for 'train'.\n",
        "    shuffle_buffer: forwarded to `Dataset.shuffle`.\n",
        "\n",
        "  Returns:\n",
        "    A `tf.data.Dataset` of dicts with a 128x128 'image' mapped through\n",
        "    (x - 127.5) / 127.5 and a one-hot 'label' of depth 100.\n",
        "  \"\"\"\n",
        "  data = data_builder.as_dataset(split=split)\n",
        "\n",
        "  if split == 'train':\n",
        "    # Pull the split into memory as one batch (only the first 50000 examples\n",
        "    # are taken) so that it can be indexed with NumPy.\n",
        "    data = next(data.batch(50000).as_numpy_iterator())\n",
        "\n",
        "    # A private, seeded generator keeps the subset reproducible without\n",
        "    # reseeding NumPy's global RNG; it draws the same indices as\n",
        "    # np.random.seed(0) followed by np.random.choice.\n",
        "    rng = np.random.RandomState(0)\n",
        "    indices = [idx\n",
        "               for cls in range(100)\n",
        "               for idx in rng.choice(np.where(data['label'] == cls)[0],\n",
        "                                     images_per_class,\n",
        "                                     replace=False)]\n",
        "\n",
        "    data = tf.data.Dataset.from_tensor_slices(\n",
        "        {'image': data['image'][indices], 'label': data['label'][indices]})\n",
        "  else:\n",
        "    data = data.map(lambda d: {'image': d['image'], 'label': d['label']})\n",
        "\n",
        "  def _pp(data):\n",
        "    im = data['image']\n",
        "    if split == 'train':\n",
        "      im = tf.image.resize(im, [160, 160])\n",
        "      im = tf.image.random_crop(im, [128, 128, 3])\n",
        "      # Flip at random: an unconditional flip_left_right would mirror every\n",
        "      # training image and so would not augment anything.\n",
        "      im = tf.image.random_flip_left_right(im)\n",
        "    else:\n",
        "      im = tf.image.resize(im, [128, 128])\n",
        "    im = (im - 127.5) / 127.5\n",
        "    return {'image': im, 'label': tf.one_hot(data['label'], 100)}\n",
        "\n",
        "  data = data.repeat(repeats)\n",
        "  data = data.shuffle(shuffle_buffer)\n",
        "  data = data.map(_pp)\n",
        "  return data.batch(batch_size)\n",
        "\n",
        "# 5 images per class for training (repeats=None), one pass over the test split.\n",
        "data_train = get_data(\n",
        "    split='train', repeats=None, batch_size=64,\n",
        "    images_per_class=5, shuffle_buffer=500)\n",
        "data_test = get_data(\n",
        "    split='test', repeats=1, batch_size=250,\n",
        "    images_per_class=None, shuffle_buffer=1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "88htJnbRUhGU",
        "colab_type": "text"
      },
      "source": [
        "## Build model and load weights"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P5eKHLdLWv9n",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@jax.jit\n",
        "def resnet_fn(params, images):\n",
        "  \"\"\"Returns the 100-way logits of the ResNet for `images` (jit-compiled).\"\"\"\n",
        "  model = ResNet.partial(num_classes=100)\n",
        "  return model.call(params, images)\n",
        "\n",
        "def cross_entropy_loss(*, logits, labels):\n",
        "  \"\"\"Mean over the batch of the softmax cross-entropy with `labels`.\n",
        "\n",
        "  `labels` is multiplied element-wise with the log-probabilities, so it is\n",
        "  expected to match the shape of `logits` (one-hot rows in this notebook).\n",
        "  \"\"\"\n",
        "  log_probs = jax.nn.log_softmax(logits)\n",
        "  per_example = jnp.sum(log_probs * labels, axis=1)\n",
        "  return -jnp.mean(per_example)\n",
        "\n",
        "def loss_fn(params, images, labels):\n",
        "  \"\"\"Cross-entropy of the model's predictions on one batch.\"\"\"\n",
        "  return cross_entropy_loss(logits=resnet_fn(params, images), labels=labels)\n",
        "\n",
        "@jax.jit\n",
        "def update_fn(opt, lr, images, labels):\n",
        "  \"\"\"Takes one gradient step; returns the new optimizer and the batch loss.\n",
        "\n",
        "  The loss is evaluated at `opt.target`, i.e. before the update is applied.\n",
        "  \"\"\"\n",
        "  loss, grads = jax.value_and_grad(loss_fn)(opt.target, images, labels)\n",
        "  return opt.apply_gradient(grads, learning_rate=lr), loss\n",
        "\n",
        "with tf.io.gfile.GFile('gs://bit_models/BiT-M-R50x1.npz', 'rb') as f:\n",
        "  # np.load on an .npz reads each array lazily from `f`, so copy all of them\n",
        "  # out while the file is still open.\n",
        "  params_tf = dict(np.load(f).items())\n",
        "\n",
        "# Initialize a parameter tree for a 100-class model, then overwrite it in\n",
        "# place with the BiT weights; init_head=False leaves an all-zero head.\n",
        "resnet_init = ResNet.partial(num_classes=100).init_by_shape\n",
        "_, params = resnet_init(jax.random.PRNGKey(0), [([1, 224, 224, 3], jnp.float32)])\n",
        "transform_params(params, params_tf, 100, init_head=False)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kfvUYwKTUlRc",
        "colab_type": "text"
      },
      "source": [
        "## Run optimization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4begAUoFWv65",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_lr(step, base_lr=0.003, warmup_steps=100, decay_steps=(200, 300, 400)):\n",
        "  \"\"\"Learning rate at `step`: linear warm-up, then step-wise decay.\n",
        "\n",
        "  Args:\n",
        "    step: training step, counted from 0.\n",
        "    base_lr: rate reached at the end of the warm-up.\n",
        "    warmup_steps: the rate grows linearly from 0 while `step` is below this.\n",
        "    decay_steps: the rate is divided by 10 once for every entry that `step`\n",
        "      has strictly exceeded.\n",
        "\n",
        "  Returns:\n",
        "    The learning rate for `step`.\n",
        "  \"\"\"\n",
        "  if step < warmup_steps:\n",
        "    return base_lr * (step / warmup_steps)\n",
        "\n",
        "  lr = base_lr\n",
        "  for boundary in decay_steps:\n",
        "    if boundary < step:\n",
        "      lr /= 10\n",
        "  return lr"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mDMxID3qWv4r",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        },
        "outputId": "26ffc369-aba8-4521-9638-92979301903d"
      },
      "source": [
        "opt = optim.Momentum(beta=0.9).create(params)\n",
        "\n",
        "for step, batch in zip(range(500), data_train.as_numpy_iterator()):\n",
        "  opt, loss_value = update_fn(\n",
        "      opt, get_lr(step), batch['image'], batch['label'])\n",
        "\n",
        "  # Evaluate on the test set only when the optimizer's step counter is a\n",
        "  # multiple of 100.\n",
        "  if opt.state.step % 100 != 0:\n",
        "    continue\n",
        "\n",
        "  hits = []\n",
        "  for test_batch in data_test.as_numpy_iterator():\n",
        "    expected = np.argmax(test_batch['label'], axis=1)\n",
        "    predicted = np.argmax(resnet_fn(opt.target, test_batch['image']), axis=1)\n",
        "    hits.extend(expected == predicted)\n",
        "  acc = np.mean(hits)\n",
        "  print(f\"Step: {opt.state.step}, Test accuracy: {acc:0.3%}\")"
      ],
      "execution_count": 55,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step: 100, Test accuracy: 61.400%\n",
            "Step: 200, Test accuracy: 63.430%\n",
            "Step: 300, Test accuracy: 63.690%\n",
            "Step: 400, Test accuracy: 63.710%\n",
            "Step: 500, Test accuracy: 63.700%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZHydBZw4Wl2K",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}